# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Callable, Hashable, Iterable, Sequence
from typing import Any, TypeVar

# Module version as an int; this stub only declares the type, the actual
# value is supplied by the implementation (presumably bumped on API changes
# — confirm).
version: int = ...

# Plain generic type variable (used e.g. by `PyTreeRegistry.register_node`).
_T = TypeVar("_T")

# Node decomposition: the iterable of children and the hashable auxiliary
# data produced when a node is flattened.
_Children = TypeVar("_Children", bound=Iterable[Any])
_AuxData = TypeVar("_AuxData", bound=Hashable)

# Key-related shapes (by name): a (key, leaf) pair, an iterable of such
# pairs, and a key path bounded by a variable-length tuple.
_KeyLeafPair = TypeVar("_KeyLeafPair", bound=tuple[Any, Any])
_KeyLeafPairs = TypeVar("_KeyLeafPairs", bound=Iterable[tuple[Any, Any]])
_KeyPath = TypeVar("_KeyPath", bound=tuple[Any, ...])

class PyTreeRegistry:
  """Stub for the registry of pytree node types.

  Only signatures are declared here; the implementation is not part of this
  file. The ``enable_*`` constructor flags presumably select which built-in
  containers are treated as pytree nodes — confirm against the implementation.
  """

  def __init__(
      self,
      enable_none: bool = ...,
      enable_tuple: bool = ...,
      enable_namedtuple: bool = ...,
      enable_list: bool = ...,
      enable_dict: bool = ...,
  ) -> None: ...
  # Returns a list of leaves plus the PyTreeDef describing `tree`.
  # `leaf_predicate`, if given, presumably marks extra values as leaves.
  def flatten(
      self,
      tree: object | None,
      leaf_predicate: Callable[[Any], bool] | None = None,
  ) -> tuple[list[Any], PyTreeDef]: ...
  # Flattens only the top level of `tree`; returns None when there is
  # nothing to flatten (presumably when `tree` is a leaf — confirm).
  def flatten_one_level(
      self, tree: object | None
  ) -> tuple[Iterable[Any], Any] | None: ...
  # Like `flatten_one_level`, but each child comes paired with a key.
  # NOTE: concrete types are used instead of a TypeVar that would appear only
  # in the return type; type checkers cannot solve such TypeVars.
  def flatten_one_level_with_keys(
      self, tree: object | None
  ) -> tuple[Iterable[tuple[Any, Any]], Any] | None: ...
  # Like `flatten`, but each leaf is paired with its key path (a tuple).
  def flatten_with_path(
      self,
      tree: object | None,
      leaf_predicate: Callable[[Any, Any], bool] | None = None,
  ) -> tuple[list[tuple[tuple[Any, ...], Any]], PyTreeDef]: ...
  # `to_iterable` maps a node to (children, aux_data); `from_iterable` maps
  # back — note it takes aux_data FIRST, then children.
  def register_node(
      self,
      type: type[_T],
      to_iterable: Callable[[_T], tuple[_Children, _AuxData]],
      from_iterable: Callable[[_AuxData, _Children], _T],
      to_iterable_with_keys: (
          Callable[[_T], tuple[Iterable[tuple[Any, Any]], _AuxData]] | None
      ) = None,
  ) -> Any: ...
  # By name: `data_fields` become children, `meta_fields` become static
  # metadata of the node — confirm against the implementation.
  def register_dataclass_node(
      self,
      type: type,
      data_fields: Sequence[str],
      meta_fields: Sequence[str],
      /,
  ) -> Any: ...
  # Pickle hook; a str return value is interpreted by pickle as the name of
  # a global to look up on unpickling.
  def __reduce__(self) -> str: ...

# Module-level PyTreeRegistry instance; this stub only declares its type.
# Presumably the object returned by `default_registry()` — confirm.
_default_registry: PyTreeRegistry = ...

def default_registry() -> PyTreeRegistry:
  """Returns a `PyTreeRegistry` — by name, the module's default registry."""
def treedef_tuple(
    registry: PyTreeRegistry,
    arg0: Sequence[PyTreeDef],
    /,
) -> PyTreeDef:
  """Returns a `PyTreeDef` assembled from the sub-treedefs in `arg0`.

  By name, presumably a tuple node whose children are `arg0`, resolved
  against `registry` — confirm against the implementation.
  """
def all_leaves(arg0: PyTreeRegistry, arg1: Iterable[Any], /) -> bool:
  """Returns a bool about the items of `arg1` under registry `arg0`.

  By name, presumably whether every item of `arg1` is a leaf — confirm.
  The element type is spelled `Iterable[Any]` explicitly rather than a bare
  `Iterable`, matching the rest of this stub.
  """

class PyTreeDef:
  """Stub for a pytree structure descriptor.

  Declarations only; the implementation is not part of this file.
  """

  def unflatten(self, arg: Iterable[Any], /) -> Any: ...
  def flatten_up_to(self, tree: object | None) -> list[Any]: ...
  def compose(self, arg: PyTreeDef, /) -> PyTreeDef: ...
  # All parameters are positional-only (`/`), so plain names are used; the
  # old dunder-prefix convention is redundant here and would be name-mangled.
  def walk(
      self,
      f_node: Callable[[Any, Any], Any],
      f_leaf: Callable[[Any], Any] | None,
      leaves: Iterable[Any],
      /,
  ) -> Any:
    """Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf at leaves."""

  def from_iterable_tree(self, arg: object, /) -> object: ...
  def children(self) -> list[PyTreeDef]: ...
  @property
  def num_leaves(self) -> int: ...
  @property
  def num_nodes(self) -> int: ...
  def __repr__(self) -> str: ...
  def __eq__(self, arg: object, /) -> bool: ...
  def __ne__(self, arg: object, /) -> bool: ...
  def __hash__(self) -> int: ...
  def serialize_using_proto(self) -> bytes: ...
  @staticmethod
  def deserialize_using_proto(
      registry: PyTreeRegistry, data: bytes
  ) -> PyTreeDef: ...
  def node_data(self) -> tuple[type, Any] | None:
    """Returns None if a leaf-pytree, else (type, node_data)"""

  # Static method: takes no `self`. Call as
  # `PyTreeDef.from_node_data_and_children(registry, node_data, children)`.
  @staticmethod
  def from_node_data_and_children(
      registry: PyTreeRegistry,
      node_data: tuple[type, Any] | None,
      children: Iterable[PyTreeDef],
  ) -> PyTreeDef:
    """Reconstructs a pytree from `node_data()` and `children()`."""

  # Pickle support hooks.
  def __getstate__(self) -> object: ...
  def __setstate__(self, arg: object, /) -> None: ...

class SequenceKey(Hashable):
  """Path entry identifying a child by integer position.

  Built from an ``int`` and exposed as ``idx``. Presumably an element of the
  key paths produced by ``PyTreeRegistry.flatten_with_path`` — confirm.
  """

  def __init__(self, idx: int) -> None: ...
  def __str__(self) -> str: ...
  def __repr__(self) -> str: ...
  def __eq__(self, arg: object, /) -> bool: ...
  def __hash__(self) -> int: ...
  @property
  def idx(self) -> int: ...

  # Attribute names for positional `match` class patterns; the data model
  # requires a tuple of strings.
  __match_args__: tuple[str, ...] = ...

  # Pickle support hooks; the state is a tuple.
  def __getstate__(self) -> tuple[Any, ...]: ...
  def __setstate__(self, arg: tuple[Any, ...], /) -> None: ...

class DictKey(Hashable):
  """Path entry identifying a child by an arbitrary key object.

  Built from any ``object`` and exposed as ``key``. Presumably an element of
  the key paths produced by ``PyTreeRegistry.flatten_with_path`` — confirm.
  """

  def __init__(self, key: object) -> None: ...
  def __str__(self) -> str: ...
  def __repr__(self) -> str: ...
  def __eq__(self, arg: object, /) -> bool: ...
  def __hash__(self) -> int: ...
  @property
  def key(self) -> object: ...

  # Attribute names for positional `match` class patterns; the data model
  # requires a tuple of strings.
  __match_args__: tuple[str, ...] = ...

  # Pickle support hooks; the state is a tuple.
  def __getstate__(self) -> tuple[Any, ...]: ...
  def __setstate__(self, arg: tuple[Any, ...], /) -> None: ...

class GetAttrKey(Hashable):
  """Path entry identifying a child by attribute name.

  Built from a ``str`` and exposed as ``name``. Presumably an element of the
  key paths produced by ``PyTreeRegistry.flatten_with_path`` — confirm.
  """

  def __init__(self, name: str) -> None: ...
  def __str__(self) -> str: ...
  def __repr__(self) -> str: ...
  def __eq__(self, arg: object, /) -> bool: ...
  def __hash__(self) -> int: ...
  @property
  def name(self) -> str: ...

  # Attribute names for positional `match` class patterns; the data model
  # requires a tuple of strings.
  __match_args__: tuple[str, ...] = ...

  # Pickle support hooks; the state is a tuple.
  def __getstate__(self) -> tuple[Any, ...]: ...
  def __setstate__(self, arg: tuple[Any, ...], /) -> None: ...

class FlattenedIndexKey(Hashable):
  """Path entry identifying a child by an integer index.

  Built from an ``int`` and exposed as ``key``. By name, presumably an index
  into a node's flattened children — confirm against the implementation.
  """

  def __init__(self, key: int) -> None: ...
  def __str__(self) -> str: ...
  def __repr__(self) -> str: ...
  def __eq__(self, arg: object, /) -> bool: ...
  def __hash__(self) -> int: ...
  @property
  def key(self) -> int: ...

  # Attribute names for positional `match` class patterns; the data model
  # requires a tuple of strings.
  __match_args__: tuple[str, ...] = ...

  # Pickle support hooks; the state is a tuple.
  def __getstate__(self) -> tuple[Any, ...]: ...
  def __setstate__(self, arg: tuple[Any, ...], /) -> None: ...
